# Training loop (no context conditioning): train the model to predict the
# Gaussian noise added to images at random diffusion timesteps (DDPM objective).

# Put the network into train mode (enables dropout / batch-norm updates).
nn_model.train()

for ep in range(n_epoch):
    print(f'epoch {ep}')

    # Linearly decay the learning rate toward 0 over the course of training.
    optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

    pbar = tqdm(dataloader, mininterval=2)
    for x, _ in pbar:  # x: batch of images; labels are unused here
        optim.zero_grad()
        x = x.to(device)

        # Perturb the clean images: sample per-image Gaussian noise and a
        # random timestep t in [1, timesteps]. Sampling t directly on the
        # target device avoids a CPU allocation followed by a copy.
        noise = torch.randn_like(x)
        t = torch.randint(1, timesteps + 1, (x.shape[0],), device=device)
        x_pert = perturb_input(x, t, noise)

        # The network predicts the injected noise from the perturbed image
        # and the normalized timestep t/timesteps.
        pred_noise = nn_model(x_pert, t / timesteps)

        # Loss is mean squared error between predicted and true noise.
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()
        optim.step()

    # Save a checkpoint every 4 epochs and always at the final epoch.
    if ep % 4 == 0 or ep == n_epoch - 1:
        # makedirs(exist_ok=True) creates missing parents and avoids the
        # race-prone exists()/mkdir() pair.
        os.makedirs(save_dir, exist_ok=True)
        # os.path.join builds a valid path whether or not save_dir ends
        # with a separator (plain '+' silently mangled the filename).
        ckpt_path = os.path.join(save_dir, f"model_{ep}.pth")
        torch.save(nn_model.state_dict(), ckpt_path)
        print('saved model at ' + ckpt_path)